[PyTorch][Fused Attn] Add support for cuDNN to return Softmax Stats always and Max when return_max_logit=True #2677
Conversation
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Greptile Summary

This PR adapts TransformerEngine to leverage cuDNN's recent enhancement that allows returning any subset of {Stats, SumExp, Max}. The implementation now always retrieves `Stats` from cuDNN, and retrieves `Max` only when `return_max_logit=True`.

Key changes: …

All previous review feedback has been addressed, including updated comments in …

Confidence Score: 5/5
Important Files Changed
Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
A[Fused Attention Forward Pass] --> B{generate_stats = true}
B --> C[cuDNN SDPA Operation]
C --> D[Always returns Stats tensor]
D --> E{return_max_logit?}
E -->|true| F[Also set Max output]
E -->|false| G[Max = nullptr]
F --> H[Output: O, Stats, Max]
G --> I[Output: O, Stats]
H --> J[Python Layer: aux_ctx_tensors = Stats]
I --> K[Python Layer: aux_ctx_tensors = Stats]
J --> L[Compute max_logit from Max tensor]
L --> M[Return: O, aux_ctx_tensors, max_logit]
K --> N[Return: O, aux_ctx_tensors]
style D fill:#90EE90
style F fill:#FFD700
style H fill:#87CEEB
style M fill:#87CEEB
Last reviewed commit: 56e46fd
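For reference, the Python-layer flow in the chart roughly corresponds to the sketch below. The function name, unpacking, and the default `amax_dims` are illustrative assumptions, not the actual TE code:

```python
import torch

def unpack_fused_attn_outputs(output_tensors, return_max_logit, amax_dims=(0, 2, 3)):
    # cuDNN now always returns Stats, so output_tensors is either
    # [out, Stats] or [out, Stats, Max] depending on return_max_logit.
    out, stats = output_tensors[0], output_tensors[1]
    aux_ctx_tensors = [stats]  # Stats alone is what the backward pass needs
    if return_max_logit:
        # Reduce cuDNN's Max tensor to one value per head for max_logit.
        max_logit = torch.amax(output_tensors[2], dim=amax_dims).to(dtype=out.dtype)
        return out, aux_ctx_tensors, max_logit
    return out, aux_ctx_tensors
```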
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Additional Comments (1)
The public docstring still describes …
stats = output_tensors[1] + torch.log(output_tensors[2])
# thd: output_tensors: out [tq, h, d], Stats [tq, h, 1], Max [tq, h, 1]
# bshd: output_tensors: out [b, sq, h, d], Stats [b, h, sq, 1], Max [b, h, sq, 1]
# sbhd: output_tensors: out [sq, b, h, d], Stats [b, h, sq, 1], Max [b, h, sq, 1] (there's no typo here)
Do we need the "there's no typo here" :)
I deliberately added it because I didn't believe it and checked the shapes myself :P
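For context on why `SumExp` can be dropped: the old reconstruction above relied on `Stats = Max + log(SumExp)`, i.e. `Stats` is the row-wise logsumexp of the attention scores, which cuDNN can now return directly. A minimal standalone sketch of the identity (dummy shapes, not TE code):

```python
import torch

# Dummy attention scores for one (batch, head): [sq, skv]
scores = torch.randn(4, 8, dtype=torch.float32)

# The per-row softmax statistics cuDNN can return, kept in [sq, 1] shape
row_max = scores.amax(dim=-1, keepdim=True)                      # Max
sum_exp = torch.exp(scores - row_max).sum(dim=-1, keepdim=True)  # SumExp
stats = row_max + torch.log(sum_exp)                             # Stats = Max + log(SumExp)

# Stats is just the row-wise logsumexp, so asking cuDNN for Stats directly
# makes SumExp redundant.
torch.testing.assert_close(stats, torch.logsumexp(scores, dim=-1, keepdim=True))
```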
size_t i = 0;
if (Aux_CTX_Tensors->size == 0) {
  const auto cudnn_runtime_version = cudnnGetVersion();
You might need to make these changes in the "Aux_CTX_Tensors->size == 0" sections in _fwd/bwd_qkvpacked/kvpacked APIs as well. Please check. Thanks!
Looks like I don't need to, because the nvte_fused...qvpacked functions are in fused_attn.cpp, which calls fused_attn_f16_arbitrary... just like the regular nvte_fused_fwd/bwd.
transformer_engine/common/include/transformer_engine/fused_attn.h (outdated comment, resolved)
# Max -> max_logit [h]
max_logit = torch.amax(output_tensors[1], dim=amax_dims).to(dtype=output_tensors[0].dtype)
aux_ctx_tensors = [stats]
max_logit = torch.amax(output_tensors[2], dim=amax_dims).to(dtype=output_tensors[0].dtype)
Maybe I understood this incorrectly, but isn't TE now also supposed to receive Max from cuDNN directly (like Stats, except Stats is always returned while Max can be toggled), rather than calling amax() in TE?
(Sudhakar: Why am I able to update your comment?)
cuDNN returns Max ([b, h, sq, 1]), so it's an additional softmax statistic (apparently the subset (Stats, Max) is enough for cuDNN bwd, rather than the full set (Stats, SumExp, Max)).
Further, for muon we need to do an amax on it to get a tensor of shape [h]. return_max_logit in TE controls whether to fetch Max from cuDNN.
Perf-wise, it'd be nice for cuDNN to do an additional reduction and return the [h]-shaped tensor for muon as well, but that's outside the scope of this PR.
(Kshitij: looks like I can as well)
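A small standalone sketch of that reduction for the bshd/sbhd layouts (made-up sizes; the `amax_dims` tuple is an assumption, not the exact TE code):

```python
import torch

b, h, sq = 2, 16, 128
# Max as returned by cuDNN for bshd/sbhd layouts: [b, h, sq, 1]
max_tensor = torch.randn(b, h, sq, 1, dtype=torch.float32)

# Reduce over batch, sequence, and the trailing singleton dim to get one
# max logit per head, shape [h], as needed for muon.
amax_dims = (0, 2, 3)
max_logit = torch.amax(max_tensor, dim=amax_dims)
print(max_logit.shape)  # torch.Size([16])
```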
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Description
cuDNN recently made returning any subset of {Stats, SumExp, Max} possible. This PR adapts TE to always get `Stats` from cuDNN, and the `Max` tensor as well if `return_max_logit=True`. (Note that `Stats = log(SumExp) + Max`.)

Type of change
Changes
Please list the changes introduced in this PR:
- `fused_attn_f16_arbitrary_seqlen.cu`
  - Remove the `SumExp` tensor as it's not needed since cuDNN returns `Stats` by default.
  - Set `generate_stats=True`, which forces cuDNN to always return the `Stats` tensor (needed in the backward pass).
- `transformer_engine/pytorch/cpp_extensions/fused_attn.py`
  - Remove the reconstruction `Stats = log(SumExp) + Max` since cuDNN returns `Stats` directly and TE doesn't need `SumExp` from cuDNN.

Checklist: